{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3b05af3b",
   "metadata": {},
   "source": [
    "(tune-rllib-example)=\n",
    "\n",
    "# Using RLlib with Tune\n",
    "\n",
    "```{image} /rllib/images/rllib-logo.png\n",
    ":align: center\n",
    ":alt: RLlib Logo\n",
    ":height: 120px\n",
    ":target: https://docs.ray.io\n",
    "```\n",
    "\n",
    "```{contents}\n",
    ":backlinks: none\n",
    ":local: true\n",
    "```\n",
    "\n",
    "## Example\n",
    "\n",
    "Example of using Population Based Training (PBT) to tune an RLlib PPO agent.\n",
    "\n",
    "Note that with `\"num_gpus\": 1` per trial, this requires a cluster with at least\n",
    "8 GPUs in order for all trials to run concurrently; otherwise PBT will\n",
    "round-robin train the trials, which is less efficient. As written, the config\n",
    "sets `\"num_gpus\": 0`, so SGD runs on CPUs instead.\n",
    "\n",
    "Note that Tune in general does not need 8 GPUs, and this is just a more\n",
    "computationally demanding example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19e3c389",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "from ray import tune\n",
    "from ray.tune.schedulers import PopulationBasedTraining\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "\n",
    "    # Postprocess the perturbed config to ensure it's still valid\n",
    "    def explore(config):\n",
    "        \"\"\"Repair a PBT-perturbed PPO config so it stays trainable.\n",
    "\n",
    "        PBT perturbs each hyperparameter independently, so the combined\n",
    "        values can be inconsistent with one another. The config is\n",
    "        clamped back into a valid range in place and then returned.\n",
    "        \"\"\"\n",
    "        # ensure we collect enough timesteps to do sgd\n",
    "        if config[\"train_batch_size\"] < config[\"sgd_minibatch_size\"] * 2:\n",
    "            config[\"train_batch_size\"] = config[\"sgd_minibatch_size\"] * 2\n",
    "        # ensure we run at least one sgd iter\n",
    "        if config[\"num_sgd_iter\"] < 1:\n",
    "            config[\"num_sgd_iter\"] = 1\n",
    "        return config\n",
    "\n",
    "    pbt = PopulationBasedTraining(\n",
    "        time_attr=\"time_total_s\",\n",
    "        perturbation_interval=120,\n",
    "        resample_probability=0.25,\n",
    "        # Specifies the mutations of these hyperparams\n",
    "        hyperparam_mutations={\n",
    "            \"lambda\": lambda: random.uniform(0.9, 1.0),\n",
    "            \"clip_param\": lambda: random.uniform(0.01, 0.5),\n",
    "            \"lr\": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],\n",
    "            \"num_sgd_iter\": lambda: random.randint(1, 30),\n",
    "            \"sgd_minibatch_size\": lambda: random.randint(128, 16384),\n",
    "            \"train_batch_size\": lambda: random.randint(2000, 160000),\n",
    "        },\n",
    "        custom_explore_fn=explore,\n",
    "    )\n",
    "\n",
    "    analysis = tune.run(\n",
    "        \"PPO\",\n",
    "        name=\"pbt_humanoid_test\",\n",
    "        scheduler=pbt,\n",
    "        # PBT exploits and explores across a population of trials; with a\n",
    "        # single trial there is nothing to copy from, so no perturbation\n",
    "        # would ever happen.\n",
    "        num_samples=8,\n",
    "        metric=\"episode_reward_mean\",\n",
    "        mode=\"max\",\n",
    "        config={\n",
    "            # Humanoid-v1 was removed from Gym; v2 is its MuJoCo successor.\n",
    "            \"env\": \"Humanoid-v2\",\n",
    "            \"kl_coeff\": 1.0,\n",
    "            \"num_workers\": 8,\n",
    "            \"num_gpus\": 0,  # GPUs per trial for SGD; set to 1 to use GPUs\n",
    "            \"model\": {\"free_log_std\": True},\n",
    "            # These params are tuned from a fixed starting value.\n",
    "            \"lambda\": 0.95,\n",
    "            \"clip_param\": 0.2,\n",
    "            \"lr\": 1e-4,\n",
    "            # These params start off randomly drawn from a set.\n",
    "            \"num_sgd_iter\": tune.choice([10, 20, 30]),\n",
    "            \"sgd_minibatch_size\": tune.choice([128, 512, 2048]),\n",
    "            \"train_batch_size\": tune.choice([10000, 20000, 40000]),\n",
    "        },\n",
    "    )\n",
    "\n",
    "    print(\"best hyperparameters: \", analysis.best_config)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fb69a24",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## More RLlib Examples\n",
    "\n",
    "- {doc}`/tune/examples/includes/pb2_ppo_example`:\n",
    "  Example of optimizing a distributed RLlib algorithm (PPO) with the PB2 scheduler.\n",
    "  Uses a small population size of 4, so it can train on a laptop."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "orphan": true
 },
 "nbformat": 4,
 "nbformat_minor": 5
}